#!/usr/bin/env python3
"""

   frida-fuzzer - fuzzer driver
   ----------------------------

   Written and maintained by Andrea Fioraldi <andreafioraldi@gmail.com>
   Based on American Fuzzy Lop by Michal Zalewski

   Copyright 2019 Andrea Fioraldi. All rights reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at:

     http://www.apache.org/licenses/LICENSE-2.0

"""

__version__ = "1.4"

import frida
import base64
import os
import sys
import time
import signal
import argparse
import tempfile
import random

# Terminal escape sequences printed before (re)drawing the status screen.
TERM_HOME = "\x1b[H"
TERM_CLEAR = TERM_HOME + "\x1b[2J"
# Title shown in the banner above the status box.
FUZZER_NAME = "frida-fuzzer %s" % __version__

# Upper bound on splice attempts per request (see on_splice).
SPLICE_CYCLES = 15 # warning: must be consistent with config.js!

# Content of the seed written by InitialSeeds.uninformed() when no input
# folder is given.
UNINFORMED_SEED = b"0" * 4

# Description text handed to the argparse parser.
DESCR = """Frida API Fuzzer [%s]
Copyright (C) 2019 Andrea Fioraldi <andreafioraldi@gmail.com>
""" % __version__

def readable_time(t):
    """Format a duration given in milliseconds as ``<h>h-<m>m-<s>s``.

    The fractional part of the seconds is dropped by the ``%d`` conversion.
    """
    total_minutes, seconds = divmod(t / 1000, 60)
    hours, minutes = divmod(total_minutes, 60)
    return "%dh-%dm-%ds" % (hours, minutes, seconds)

def get_cur_time(): # ms
    """Return the current time from ``time.time()`` as integer milliseconds."""
    now_ms = time.time() * 1000
    return int(round(now_ms))

# Command-line interface. Everything that follows the options is collected
# verbatim into ``args.target`` (argparse.REMAINDER): the first item is the
# program name or pid, the rest are the program's own arguments when spawning.
opt = argparse.ArgumentParser(description=DESCR, formatter_class=argparse.RawTextHelpFormatter)
opt.add_argument("-i", action="store", help="Folder with initial seeds")
opt.add_argument("-o", action="store", help="Output folder with intermediate seeds and crashes")
opt.add_argument("-U", action="store_true", help="Connect to USB")
opt.add_argument("-spawn", action="store_true", help="Spawn instead of attach")
opt.add_argument("-script", action="store", default="fuzzer-agent.js", help="Script filename (default is fuzzer-agent.js)")
# Help text typo fixed ("spwaning" -> "spawning").
opt.add_argument('target', nargs=argparse.REMAINDER, help="Target program/pid (and arguments if spawning)")

args = opt.parse_args()

# Validate the command line before touching the filesystem, so that a bad
# invocation does not leave a freshly created output folder behind.
if len(args.target) == 0:
    print (" >> Target not specified!")
    sys.exit (1)

if args.i and not os.path.exists(args.i):
    # The path is now interpolated; the message used to print a literal "%s".
    print (" >> %s doesn't exists!" % args.i)
    sys.exit (1)

if args.o is None:
    # No -o given: collect results in a fresh temporary directory.
    output_folder = tempfile.mkdtemp(prefix="frida_fuzz_out_")
    print (" >> Temporary output folder :", output_folder)
else:
    output_folder = args.o
    # Refuse to reuse an existing path: the folder is always created by
    # this run.
    if os.path.exists(output_folder):
        print (" >> %s already exists!" % output_folder)
        sys.exit (1)
    os.mkdir(output_folder)

# The first target item is either a numeric pid or a program name.
# ``app_name`` becomes an int when it parses as one; ``pid`` is always
# defined (None until a pid is known) so later references cannot hit an
# unbound name.
app_name = args.target[0]
pid = None
try:
    pid = int(app_name)
except ValueError:
    pass # not a PID
else:
    app_name = pid

# Read the agent source first: a missing script file then fails before any
# process is spawned or attached to.
with open(args.script) as f:
    code = f.read()

# Obtain a session either on the USB device (-U) or through the module-level
# frida helpers, spawning the target (-spawn) or attaching to an existing one.
if args.U:
    device = frida.get_usb_device()
    if args.spawn:
        pid = device.spawn(args.target)
    session = device.attach(pid if args.spawn else app_name)
elif args.spawn:
    # stdio="pipe" is passed unless FRIDA_FUZZER_CHILD_OUT is set to a
    # non-empty value.
    spawn_kwargs = {} if os.getenv("FRIDA_FUZZER_CHILD_OUT") else {"stdio": "pipe"}
    pid = frida.spawn(args.target, **spawn_kwargs)
    session = frida.attach(pid)
else:
    session = frida.attach(app_name)

script = session.create_script(code, runtime="v8")

# Mutable state shared between the message handlers and the status screen.
last_path = 0      # time (ms) the last "interesting" input was reported
start_time = 0     # time (ms) taken right before the fuzzing loop is started
last_execs = None  # total_execs at the previous status redraw (None before the first)
last_ms = 0        # time (ms) of the previous status redraw
eps = 0            # smoothed executions-per-second estimate

messages = 0       # number of messages received by on_message

def locate_diffs(a, b):
    """Find the first and last positions at which two sequences differ.

    Only the overlapping part (the length of the shorter input) is compared.
    Returns a ``(first, last)`` tuple of indices, or ``(None, None)`` when
    the overlapping parts are identical.
    """
    first = last = None
    for pos, (x, y) in enumerate(zip(a, b)):
        if x != y:
            if first is None:
                first = pos
            last = pos
    return first, last

class QEntry(object):
    """One saved test case; a node of the linked list kept by Queue."""

    def __init__(self):
        self.filename = ""       # path of the file holding the test case bytes
        self.size = 0            # length of the test case in bytes
        self.num = 0             # entry number, used as key in Queue.dict
        self.was_fuzzed = False  # NOTE(review): only read by the disabled on_next handler
        self.exec_ms = 0         # exec_ms value reported together with the test case
        self.time = 0            # time (ms) at which the entry was added
        self.new_cov = False     # new_cov flag reported together with the test case
        self.next = None         # following entry in insertion order, or None

class Queue(object):
    """Singly linked list of QEntry objects with a lookup table by number.

    Attributes:
        max_num: highest entry number added so far.
        start:   first entry of the list (None while the queue is empty).
        cur:     entry most recently returned by get().
        top:     last entry of the list, where new ones are appended.
        dict:    maps entry number to QEntry for direct lookup.
    """

    def __init__(self):
        self.max_num = 0
        self.start = None
        self.cur = None
        self.top = None
        self.dict = {}

    def add(self, buf, exec_ms, new_cov, stage, num):
        """Write *buf* to a file in the output folder and append an entry for it.

        The file is named ``id_<num>_<stage>``, with a ``_cov`` suffix when
        *new_cov* is true.
        """
        entry = QEntry()
        suffix = "_cov" if new_cov else ""
        entry.filename = os.path.join(output_folder, "id_%d_%s" % (num, stage)) + suffix
        entry.num = num
        entry.exec_ms = exec_ms
        entry.new_cov = new_cov
        entry.time = get_cur_time()
        entry.size = len(buf)
        with open(entry.filename, "wb") as out:
            out.write(buf)

        # Only link the entry once its file has been written.
        if num > self.max_num:
            self.max_num = num
        self.dict[num] = entry
        if self.top is None:
            self.start = entry
        else:
            self.top.next = entry
        self.top = entry

    def get(self):
        """Advance the cursor by one entry, wrapping around to the start.

        Returns the new current entry, or None while the queue is empty.
        """
        at_end = self.cur is None or self.cur.next is None
        self.cur = self.start if at_end else self.cur.next
        return self.cur

    def find_by_num(self, num):
        """Return the entry numbered *num*, or None if there is none."""
        return self.dict.get(num)

    def get_splice_target(self, q, buf):
        """Try to splice *buf* with another, randomly picked, queue entry.

        A random entry number is drawn; starting from that entry the list is
        walked forward until an entry of at least 2 bytes that is not *q* is
        found. Its content is compared with *buf* and, if the two differ in a
        usable range, the head of *buf* is joined with the tail of the other
        buffer at a random point inside that range.

        Returns the spliced bytes, or None when no suitable entry or split
        point exists.
        """
        candidate = self.find_by_num(random.randint(0, self.max_num))
        while candidate is not None:
            if candidate.size >= 2 and candidate.num != q.num:
                break
            candidate = candidate.next
        if candidate is None:
            return None

        with open(candidate.filename, "rb") as src:
            other = src.read()
        first, last = locate_diffs(buf, other)
        if first is None or last < 2 or first == last:
            return None
        cut = random.randint(first, last - 1)
        return buf[:cut] + other[cut:]
        

class InitialSeeds(object):
    """Ordered collection of seed file paths, handed out one at a time.

    Attributes:
        idx:   index of the next seed that get() will return.
        seeds: list of seed file paths in the order they were added.
    """

    def __init__(self):
        self.idx = 0
        self.seeds = []

    def add(self, filename):
        """Append *filename* to the list of seeds."""
        self.seeds.append(filename)

    def uninformed(self):
        """Write the default seed into the output folder and register it."""
        seed_path = os.path.join(output_folder, "uninformed")
        with open(seed_path, "wb") as out:
            out.write(UNINFORMED_SEED)
        self.add(seed_path)

    def get(self):
        """Return the next seed path and advance, or None once exhausted."""
        if self.idx >= len(self.seeds):
            return None
        seed = self.seeds[self.idx]
        self.idx += 1
        return seed


# Module-level instances used by the message handlers.
initial = InitialSeeds()  # seed files served to the agent by on_dry
queue = Queue()           # test cases stored by on_interesting

def status_screen(status):
    """Redraw the status box from a message payload sent by the agent.

    ``status`` must provide the keys ``total_execs``, ``map_rate``, ``cur``,
    ``favs``, ``pending_fav`` and ``stage``. Redraws are rate-limited: a call
    arriving less than 200 ms after the previous redraw returns immediately.

    Updates the module-level ``last_execs``, ``last_ms`` and ``eps``.
    """
    global last_execs, last_ms, eps

    cur_ms = get_cur_time()
    if cur_ms - last_ms < 200:
        return

    total_execs = status["total_execs"]
    if last_execs is None:
        # First redraw: start from the average since the loop was started.
        # The elapsed time is clamped to 1 ms so that a message arriving in
        # the same millisecond as the start cannot raise ZeroDivisionError.
        eps = float(total_execs) * 1000 / max(cur_ms - start_time, 1)
    else:
        # cur_ms - last_ms is at least 200 here, so the division is safe.
        cur_eps = float(total_execs - last_execs) * 1000 / (cur_ms - last_ms)
        # When the speed changed by more than 5x, restart the moving average
        # from the current value instead of letting it converge slowly.
        if cur_eps * 5 < eps or cur_eps / 5 > eps:
            eps = cur_eps
        eps = eps * (1.0 - 1.0 / 16) + cur_eps * (1.0 / 16)
    last_execs = total_execs
    last_ms = cur_ms

    # Box geometry: the label column has a fixed width, the value column is
    # sized after the longest of target name / output path (18 minimum).
    field_len = len(" target app  ")
    half = max(len(str(app_name)), len(output_folder), 18) +2
    half = len(" \u2551 target app   \u2502 ") + half - len(FUZZER_NAME)
    half = (half + 1) // 2  # half the padding, rounded up
    banner = " " * half + FUZZER_NAME + " " * half
    boxlen = len(banner) - 3

    def pprint(s, v=""):
        # Print one "label | value" row padded to the box width.
        s = " " + s
        s = s + " " * (field_len - len(s))
        o = " \u2551" + s + "\u2502 " + str(v)
        o += " " * ((boxlen +2) - len(o))
        o += "\u2551"
        print (o)

    print (TERM_CLEAR)
    print (banner)
    print (" \u2554" + "\u2550" * field_len + "\u2564" + "\u2550" * (boxlen - field_len -1) + "\u2557")

    pprint ("target", app_name)
    pprint ("execs", total_execs)
    pprint ("speed", "%d exec/s" % int(eps))
    pprint ("uptime", readable_time(cur_ms - start_time))
    pprint ("last path", readable_time(cur_ms - last_path))
    pprint ("map density", "{0:.2f} %".format(status["map_rate"]))
    pprint ("current", status["cur"] if status["cur"] != -1 else "<init>")
    pprint ("queue size", queue.max_num +1)
    pprint ("favoreds", status["favs"])
    pprint ("pending fav", status["pending_fav"])
    pprint ("last stage", status["stage"])
    pprint ("output path", output_folder)

    print (" \u255A" + "\u2550" * field_len + "\u2567" + "\u2550" * (boxlen - field_len -1) + "\u255D")
    print ()


def on_interesting(message, data):
    """Store a test case reported by the agent in the queue.

    *message* carries ``exec_ms``, ``new_cov``, ``stage`` and ``num``;
    *data* holds the test case bytes. Also records the current time in
    ``last_path``.
    """
    global last_path
    fields = ("exec_ms", "new_cov", "stage", "num")
    exec_ms, new_cov, stage, num = [message[key] for key in fields]
    last_path = get_cur_time()
    queue.add(data, exec_ms, new_cov, stage, num)

'''
def on_next(message, data): # dead code ATM
    global queue
    q = queue.get()
    with open(q.filename, "rb") as f:
        buf = f.read()
    script.post({
      "type": "input",
      "num": q.num,
      "buf": buf.hex(),
      "was_fuzzed": q.was_fuzzed,
    })
'''

def on_dry(message, data):
    """Reply to a "dry" request with the next initial seed.

    Posts an "input" message holding the hex-encoded seed and its index, or
    an "input" message with a None buffer once all seeds were handed out.
    """
    seed_path = initial.get()
    if seed_path is not None:
        print(" >> Dry run", seed_path)
        with open(seed_path, "rb") as seed_file:
            content = seed_file.read()
        reply = {
            "type": "input",
            "num": initial.idx - 1,  # get() already advanced the index
            "buf": content.hex(),
        }
    else:
        reply = {
            "type": "input",
            "buf": None,
        }
    script.post(reply)

def on_get(message, data):
    """Reply to a "get" request with the content of a queue entry.

    The entry is looked up by ``message["num"]`` and posted back to the
    script as a hex-encoded "input" message.
    """
    entry = queue.find_by_num(message["num"])
    with open(entry.filename, "rb") as src:
        content = src.read()
    reply = {
        "type": "input",
        "num": entry.num,
        "buf": content.hex(),
    }
    script.post(reply)

def on_splice(message, data):
    """Reply to a "splice" request for the queue entry ``message["num"]``.

    Starting from the cycle count in ``message["cycle"]``, splicing is
    retried until it succeeds or SPLICE_CYCLES is reached. The reply carries
    the hex-encoded result (None on failure) and the updated cycle count.
    """
    num = message["num"]
    splice_cycle = message["cycle"]
    entry = queue.find_by_num(num)
    with open(entry.filename, "rb") as src:
        base = src.read()

    spliced = None
    while spliced is None and splice_cycle < SPLICE_CYCLES:
        splice_cycle += 1
        spliced = queue.get_splice_target(entry, base)

    script.post({
        "type": "splice",
        "buf": None if spliced is None else spliced.hex(),  # None: failed
        "cycle": splice_cycle,
    })

def on_crash(message, data):
    """Report a crash, save the crashing input and shut the session down.

    *message* carries ``err`` (with ``type`` and optionally ``memory``) and
    ``stage``; *data* is the input, written to a ``crash_*`` file in the
    output folder. A locally spawned target is killed before the script is
    unloaded and the session detached.
    """
    print("\n"*2 + "  ============= CRASH FOUND! =============")
    err = message["err"]
    print("    type:", err["type"])
    if "memory" in err:
        mem = err["memory"]
        print("    %s at:" % mem["operation"], mem["address"])
    print("")

    stamp = int(time.time())
    crash_file = "crash_%s_%s_%d" % (message["stage"], err["type"], stamp)
    name = os.path.join(output_folder, crash_file)
    print(" >> Saving at %s" % repr(name))
    with open(name, "wb") as out:
        out.write(data)

    spawned_locally = args.spawn and not args.U
    if spawned_locally:
        print(" >> Killing", pid)
        os.kill(pid, signal.SIGKILL)
    print(" >> Press Control-C to exit...")
    script.unload()
    session.detach()

def on_exception(message, data):
    """Report an exception, save the offending input and shut the session down.

    *message* carries ``err`` and ``stage``; *data* is the input, written to
    an ``exception_*`` file in the output folder. A locally spawned target is
    killed before the script is unloaded and the session detached.
    """
    print("\n"*2 + "  =========== EXCEPTION FOUND! ===========")
    print("    message:", message["err"])
    print("")

    stamp = int(time.time())
    exc_file = "exception_%s_%d" % (message["stage"], stamp)
    name = os.path.join(output_folder, exc_file)
    print(" >> Saving at %s" % repr(name))
    with open(name, "wb") as out:
        out.write(data)

    spawned_locally = args.spawn and not args.U
    if spawned_locally:
        print(" >> Killing", pid)
        os.kill(pid, signal.SIGKILL)
    print(" >> Press Control-C to exit...")
    script.unload()
    session.detach()

def on_stats(message, data):
    """Redraw the status screen from *message*; *data* is not used."""
    status_screen(message)

def report_error(message):
    """Print an error message dict: description, line number and stack.

    ``message["description"]`` is required. ``lineNumber`` is shown when
    present and not None; ``stack`` is shown when present.
    """
    print(" ============= FUZZER ERROR! =============")
    line_no = message.get("lineNumber")
    if line_no is None:
        print("  %s" % message["description"])
    else:
        print("  line %d: %s" % (line_no, message["description"]))
    if "stack" in message:
        print("  JS stacktrace:\n")
        print(message["stack"])
    print("")

def on_message(message, data):
    """Dispatch a message received from the agent script.

    A message of type "error" is reported through report_error(), a locally
    spawned target is killed, and the script and session are torn down.
    Any other message is expected to carry a ``payload`` dict whose
    ``event`` key selects the handler; unknown events are ignored.

    *data* is the binary blob attached to the message and is passed on to
    the handlers.
    """
    global messages
    messages += 1
    if message["type"] == "error":
        report_error(message)
        if args.spawn and not args.U:
            print (" >> Killing", pid)
            os.kill(pid, signal.SIGKILL)
        print (" >> Press Control-C to exit...")
        script.unload()
        session.detach()
        # Stop here: the code used to fall through and look up
        # message["payload"], raising KeyError for error messages.
        return

    msg = message["payload"]
    event = msg["event"]
    if event == "interesting":
        on_interesting(msg, data)
        on_stats(msg, data)
    elif event == "get":
        on_get(msg, data)
        on_stats(msg, data)
    elif event == "dry":
        on_dry(msg, data)
        on_stats(msg, data)
    elif event == "splice":
        on_splice(msg, data)
        on_stats(msg, data)
    elif event == "crash":
        # Redraw first: on_crash tears the session down.
        on_stats(msg, data)
        on_crash(msg, data)
    elif event == "exception":
        # Redraw first: on_exception tears the session down.
        on_stats(msg, data)
        on_exception(msg, data)
    elif event in ("stats", "status"):
        on_stats(msg, data)

# Register on_message as the script's message callback, then load the script.
script.on('message', on_message)
script.load()

def signal_handler(sig, frame):
    """SIGINT handler: clean up on a best-effort basis and exit with status 0.

    Kills the target when it was spawned locally, tries to unload the script
    and detach the session, then terminates the process with os._exit().
    No cleanup failure is allowed to prevent the exit.
    """
    print (" >> Exiting...")
    if args.spawn and not args.U:
        print (" >> Killing", pid)
        try:
            os.kill(pid, signal.SIGKILL)
        except OSError:
            # The target may already be gone (on_crash / on_exception kill
            # it too); that must not abort the handler before os._exit().
            pass
    try:
        script.unload()
        session.detach()
    except Exception:
        # The script/session may already be torn down; ignore.
        pass
    os._exit (0)
signal.signal(signal.SIGINT, signal_handler)

# Collect the initial seeds: the built-in default seed when no input folder
# was given, otherwise every regular file directly inside that folder.
if args.i is None:
    initial.uninformed()
else:
    # os.listdir() returns names in arbitrary order; sorting gives the seeds
    # the same order (and therefore the same index) on every run.
    for fname in sorted(os.listdir(args.i)):
        p = os.path.join(args.i, fname)
        if os.path.isfile(p):
            initial.add(p)

print(TERM_CLEAR, end="")

# Reference point for the uptime / "last path" rows of the status screen.
start_time = get_cur_time()

last_path = start_time

# Call the loop() function exported by the agent script.
try:
    script.exports.loop()
except (frida.core.RPCException, frida.InvalidOperationError) as e:
    try:
        print (e)
    except Exception:
        # Printing the error is best-effort only.
        pass
    sys.exit (1)

# Block the main thread; the process is terminated by the SIGINT handler.
sys.stdin.read()

